#8282: Support non-4d tensor and fp32_dest_acc_en for moreh nllloss backward #8966
Conversation
ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
ALWI void REL() { release_dst(tt::DstMode::Half); }
#include "debug/dprint.h"  // required in all kernels using DPRINT
I think it's better to remove this.
removed the dprint header.
union {
    float f;
    uint32_t u;
} one, zero;
one.f = 1.0f;
zero.f = 0.0f;
Now using the Scalar structure.
for _ in range(2):
    tt_input_grad = ttl.operations.primary.moreh_nll_loss_backward(
        tt_target,
        tt_weight,
        tt_divisor,
        tt_output_grad,
        tt_input_grad,
        ignore_index,
        reduction_mean,
    )
What is the point of this loop? If it's to test program cache:
- add an assert on program entries
- shift the input/output memory by adding some dummy tensor
> What is the point of this loop? If it's to test program cache:
Yes, this is to test the program cache. If it runs only once, the code related to override_runtime_args_callback will not be executed.
> add an assert on program entries
I didn't understand the previous statement. There are already asserts like:
TT_ASSERT(input_tensors.size() == 2);
TT_ASSERT(optional_input_tensors.size() == 2);
TT_ASSERT(output_tensors.size() == 1);
If these are not what you mean, can you provide an example?
> shift the input/output memory by adding some dummy tensor
I've modified the callback test to receive random input every time.
If you want to test program cache:
- Ensure that the program cache is actually hit. You do this by checking that the number of generated caches for an op is 1 (or however many your test generates) when you loop it twice. You can query the number of program caches with device.num_program_cache_entries().
- Ensure that the callback is actually correct. To properly test the callback, you have to make sure that the runtime arg updates actually matter in your test. Most often, the callback updates the input/output buffer addresses. Trivially looping the test most often results in your input/output tensors being in the same location for both runs (i.e. the same buffer addresses). If this is the case, the test will always pass even when the callback does nothing, even if your data is different the second time. To convince yourself, you can print out the values of the args being updated in the callback and check whether they are the same as when the program was first compiled and launched; if they are, the test isn't really testing anything. One hack is to shift the device memory where your inputs/output are expected to be by allocating a small tensor in between the loops (see the sketch below). As an exercise, you can again comment out the callback and you will see the test actually fail the second time, since your inputs/output are now in a different location.
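A minimal sketch of this pattern, assuming the tt tensors are created once outside the loop and using a hypothetical make_small_tt_tensor helper for the dummy allocation (the real test would use whatever allocation helpers the suite already has):

for i in range(2):
    tt_loss = ttl.operations.primary.moreh_nll_loss(
        tt_input,
        tt_target,
        tt_weight,
        tt_divisor,
        tt_output,
        ignore_index,
        reduction_mean,
    )
    if i == 0:
        # Shift device memory so the second run sees different buffer
        # addresses and therefore exercises the runtime-arg callback.
        dummy = make_small_tt_tensor(device)  # hypothetical helper

# The op should compile once; the second run must hit the program cache.
# 1 is illustrative; use however many programs your test actually generates.
assert device.num_program_cache_entries() == 1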
> Trivially looping the test most often results in your input/output tensors being in the same location for both runs (i.e. the same buffer addresses).
Thank you for the comment. As you mentioned, by running the for loop twice without tensor allocation, I confirmed that the tensor addresses remained the same. I therefore moved the tensor creation inside the for loop and observed that the addresses changed. Since the buffer addresses now change without a dummy tensor, the dummy tensor seems unnecessary.
This has been fixed in f3c5ab9.
for _ in range(2):
    # In each loop, a new tt tensor and value are created.
    (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
    if none_weight:
        torch_weight = None

    (tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
        torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
    )

    tt_loss = ttl.operations.primary.moreh_nll_loss(
        tt_input,
        tt_target,
        tt_weight,
        tt_divisor,
        tt_output,
        ignore_index,
        reduction_mean,
    )
> device.num_program_cache_entries()

The above function appears to return the number of generated caches. However, in the case of NLL loss, it internally calls moreh_sum, and the number of generated caches depends on the implementation of moreh_sum. Therefore, checking this number might not be the correct approach.
Instead of num_program_cache_entries, how about adding a boolean flag to ProgramCache along with enable_cache_check() and disable_cache_check() functions, and then incorporating checks like TT_ASSERT(cache_hit)?
struct ProgramCache {
    inline std::optional<std::shared_ptr<void>> find(uint64_t program_hash) {
        auto cache_hit = this->cache_.count(program_hash) > 0;
        if (is_cache_check_enabled_) {
            TT_ASSERT(cache_hit);
        }
        if (cache_hit) {
            return this->cache_.at(program_hash);
        }
        return std::nullopt;
    }

    void enable_cache_check() {
        is_cache_check_enabled_ = true;
    }

    void disable_cache_check() {
        is_cache_check_enabled_ = false;
    }

   private:
    inline static bool is_cache_check_enabled_ = false;
};
def test_callback():
    ...
    for i in range(2):
        if i == 1:
            # After enabling cache_check, if a cache miss occurs, an assertion is triggered.
            device.enable_cache_check()
        run_tt_op()
Your method without dummy tensors probably also works because Python doesn't deallocate the old tt tensors when it creates the next set.
I don't think we need to add anything new to ProgramCache. It doesn't matter which implementation of moreh_sum this test uses; you just have to assert against the number of expected caches.
Two resolved (outdated) review threads on ...oreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp
1fb8d3f to d3b2857:
- refactoring moreh nllloss backward
- Add moreh helper functions for fp32_dest_acc_en